{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "# Root directory of the local CIFAR-10 copy; download=False below means the files\n",
    "# must already be present there.\n",
    "data_path = '../data/p1ch7/'\n",
    "\n",
    "# Single preprocessing pipeline shared by both splits: convert each image to a\n",
    "# tensor, then normalize the three channels with the given mean and std tuples.\n",
    "cifar10_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
    "                         (0.2470, 0.2435, 0.2616)),\n",
    "])\n",
    "\n",
    "transformed_cifar10 = datasets.CIFAR10(data_path, train=True, download=False,\n",
    "                                       transform=cifar10_transform)\n",
    "transformed_cifar10_val = datasets.CIFAR10(data_path, train=False, download=False,\n",
    "                                           transform=cifar10_transform)\n",
    "\n",
    "# Remap the two kept CIFAR-10 labels to contiguous class indices (0 stays 0, 2 becomes 1).\n",
    "label_map = {0: 0, 2: 1}\n",
    "\n",
    "# Display names, indexed by the remapped label.\n",
    "class_names = ['airplane', 'bird']\n",
    "\n",
    "# Keep only the samples whose original label is a key of label_map and pair each\n",
    "# image with its remapped label. Filtering on label_map itself keeps the filter and\n",
    "# the mapping in sync if more classes are added later.\n",
    "cifar2 = [(img, label_map[label])\n",
    "          for img, label in transformed_cifar10\n",
    "          if label in label_map]\n",
    "cifar2_val = [(img, label_map[label])\n",
    "              for img, label in transformed_cifar10_val\n",
    "              if label in label_map]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Loss: 91.559662\n",
      "Epoch: 1, Loss: 85.259102\n",
      "Epoch: 2, Loss: 0.000000\n",
      "Epoch: 3, Loss: 36.619743\n",
      "Epoch: 4, Loss: 166.960663\n",
      "Epoch: 5, Loss: 0.000713\n",
      "Epoch: 6, Loss: 75.298431\n",
      "Epoch: 7, Loss: 0.007719\n",
      "Epoch: 8, Loss: 86.725845\n",
      "Epoch: 9, Loss: 55.219326\n",
      "Epoch: 10, Loss: 30.440813\n",
      "Epoch: 11, Loss: 4.653501\n",
      "Epoch: 12, Loss: 8.028452\n",
      "Epoch: 13, Loss: 94.359947\n",
      "Epoch: 14, Loss: 14.110658\n",
      "Epoch: 15, Loss: 96.579865\n",
      "Epoch: 16, Loss: 113.207542\n",
      "Epoch: 17, Loss: 0.000000\n",
      "Epoch: 18, Loss: 33.517952\n",
      "Epoch: 19, Loss: 84.194016\n",
      "Epoch: 20, Loss: 53.602997\n",
      "Epoch: 21, Loss: 9.633520\n",
      "Epoch: 22, Loss: 5.284300\n",
      "Epoch: 23, Loss: 40.722492\n",
      "Epoch: 24, Loss: 36.068230\n",
      "Epoch: 25, Loss: 105.185059\n",
      "Epoch: 26, Loss: 90.927505\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "# Fully connected classifier: 3072 input features, a 512-unit Tanh hidden layer,\n",
    "# and 2 outputs turned into log-probabilities by LogSoftmax over dim 1.\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(3072, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, 2),\n",
    "    nn.LogSoftmax(dim=1),\n",
    ")\n",
    "\n",
    "n_epochs = 100\n",
    "learning_rate = 1e-1\n",
    "loss_fn = nn.NLLLoss()  # consumes the log-probabilities produced by LogSoftmax\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# NOTE(review): the saved output of this cell swings between ~0 and >100 from one\n",
    "# epoch to the next, which suggests lr=1e-1 is too large for single-sample updates;\n",
    "# consider 1e-2 and confirm by re-running.\n",
    "# model.cuda()\n",
    "for epoch in range(n_epochs):\n",
    "    # Samples are fed one at a time (no mini-batching).\n",
    "    for img, label in cifar2:\n",
    "        optimizer.zero_grad()                       # reset the parameter gradients to zero\n",
    "        out = model(img.view(1, -1))                # flatten the image into a batch of one\n",
    "        loss = loss_fn(out, torch.tensor([label]))\n",
    "        loss.backward()                             # backpropagate to compute the gradients\n",
    "        optimizer.step()                            # update the parameters\n",
    "    # Only the loss of the last sample processed in the epoch is reported.\n",
    "    print('Epoch: %d, Loss: %f' % (epoch, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serve cifar2 in mini-batches of 64 samples, reshuffled on every pass.\n",
    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Loss: 0.336142\n",
      "Epoch: 1, Loss: 0.510555\n",
      "Epoch: 2, Loss: 0.499815\n",
      "Epoch: 3, Loss: 0.563915\n",
      "Epoch: 4, Loss: 0.243007\n",
      "Epoch: 5, Loss: 0.383894\n",
      "Epoch: 6, Loss: 0.380500\n",
      "Epoch: 7, Loss: 0.365074\n",
      "Epoch: 8, Loss: 0.415585\n",
      "Epoch: 9, Loss: 0.443203\n",
      "Epoch: 10, Loss: 0.435805\n",
      "Epoch: 11, Loss: 0.230984\n",
      "Epoch: 12, Loss: 0.336917\n",
      "Epoch: 13, Loss: 0.464410\n",
      "Epoch: 14, Loss: 0.170214\n",
      "Epoch: 15, Loss: 0.200571\n",
      "Epoch: 16, Loss: 0.194849\n",
      "Epoch: 17, Loss: 0.602790\n",
      "Epoch: 18, Loss: 0.345542\n",
      "Epoch: 19, Loss: 0.508246\n",
      "Epoch: 20, Loss: 0.277048\n",
      "Epoch: 21, Loss: 0.589186\n",
      "Epoch: 22, Loss: 0.168732\n",
      "Epoch: 23, Loss: 0.188364\n",
      "Epoch: 24, Loss: 0.222695\n",
      "Epoch: 25, Loss: 0.332055\n",
      "Epoch: 26, Loss: 0.552662\n",
      "Epoch: 27, Loss: 0.337513\n",
      "Epoch: 28, Loss: 0.437499\n",
      "Epoch: 29, Loss: 0.203606\n",
      "Epoch: 30, Loss: 0.179525\n",
      "Epoch: 31, Loss: 0.209785\n",
      "Epoch: 32, Loss: 0.210296\n",
      "Epoch: 33, Loss: 0.173690\n",
      "Epoch: 34, Loss: 0.241678\n",
      "Epoch: 35, Loss: 0.247828\n",
      "Epoch: 36, Loss: 0.193895\n",
      "Epoch: 37, Loss: 0.103355\n",
      "Epoch: 38, Loss: 0.103795\n",
      "Epoch: 39, Loss: 0.071429\n",
      "Epoch: 40, Loss: 0.068258\n",
      "Epoch: 41, Loss: 0.117616\n",
      "Epoch: 42, Loss: 0.096629\n",
      "Epoch: 43, Loss: 0.123419\n",
      "Epoch: 44, Loss: 0.073161\n",
      "Epoch: 45, Loss: 0.257877\n",
      "Epoch: 46, Loss: 0.067850\n",
      "Epoch: 47, Loss: 0.188544\n",
      "Epoch: 48, Loss: 0.158699\n",
      "Epoch: 49, Loss: 0.138893\n",
      "Epoch: 50, Loss: 0.086552\n",
      "Epoch: 51, Loss: 0.018627\n",
      "Epoch: 52, Loss: 0.028799\n",
      "Epoch: 53, Loss: 0.053890\n",
      "Epoch: 54, Loss: 0.030658\n",
      "Epoch: 55, Loss: 0.070645\n",
      "Epoch: 56, Loss: 0.036666\n",
      "Epoch: 57, Loss: 0.067107\n",
      "Epoch: 58, Loss: 0.028115\n",
      "Epoch: 59, Loss: 0.048577\n",
      "Epoch: 60, Loss: 0.044142\n",
      "Epoch: 61, Loss: 0.033708\n",
      "Epoch: 62, Loss: 0.037068\n",
      "Epoch: 63, Loss: 0.058428\n",
      "Epoch: 64, Loss: 0.057785\n",
      "Epoch: 65, Loss: 0.033809\n",
      "Epoch: 66, Loss: 0.123783\n",
      "Epoch: 67, Loss: 0.046827\n",
      "Epoch: 68, Loss: 0.014269\n",
      "Epoch: 69, Loss: 0.010460\n",
      "Epoch: 70, Loss: 0.028332\n",
      "Epoch: 71, Loss: 0.134537\n",
      "Epoch: 72, Loss: 0.116367\n",
      "Epoch: 73, Loss: 0.022904\n",
      "Epoch: 74, Loss: 0.019132\n",
      "Epoch: 75, Loss: 0.010893\n",
      "Epoch: 76, Loss: 0.071083\n",
      "Epoch: 77, Loss: 0.034385\n",
      "Epoch: 78, Loss: 0.025175\n",
      "Epoch: 79, Loss: 0.026734\n",
      "Epoch: 80, Loss: 0.014235\n",
      "Epoch: 81, Loss: 0.017317\n",
      "Epoch: 82, Loss: 0.057643\n",
      "Epoch: 83, Loss: 0.006445\n",
      "Epoch: 84, Loss: 0.026685\n",
      "Epoch: 85, Loss: 0.023807\n",
      "Epoch: 86, Loss: 0.015605\n",
      "Epoch: 87, Loss: 0.030485\n",
      "Epoch: 88, Loss: 0.024027\n",
      "Epoch: 89, Loss: 0.011118\n",
      "Epoch: 90, Loss: 0.023453\n",
      "Epoch: 91, Loss: 0.013384\n",
      "Epoch: 92, Loss: 0.010460\n",
      "Epoch: 93, Loss: 0.017102\n",
      "Epoch: 94, Loss: 0.020958\n",
      "Epoch: 95, Loss: 0.020761\n",
      "Epoch: 96, Loss: 0.010123\n",
      "Epoch: 97, Loss: 0.004053\n",
      "Epoch: 98, Loss: 0.009519\n",
      "Epoch: 99, Loss: 0.010571\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "# Serve the training pairs in shuffled mini-batches of 64 samples.\n",
    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
    "\n",
    "# Fully connected classifier: 3072 input features, a 512-unit Tanh hidden layer,\n",
    "# and 2 outputs turned into log-probabilities by LogSoftmax over dim 1.\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(3072, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, 2),\n",
    "    nn.LogSoftmax(dim=1),\n",
    ")\n",
    "\n",
    "n_epochs = 100\n",
    "learning_rate = 1e-2\n",
    "loss_fn = nn.NLLLoss()  # consumes the log-probabilities produced by LogSoftmax\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# model.cuda()\n",
    "for epoch in range(n_epochs):\n",
    "    # imgs has shape [batch, 3, 32, 32]; batch is 64 except for a possibly smaller\n",
    "    # final batch when the dataset size is not a multiple of 64.\n",
    "    for imgs, labels in train_loader:\n",
    "        batch_size = imgs.shape[0]\n",
    "        optimizer.zero_grad()                       # reset the parameter gradients to zero\n",
    "        out = model(imgs.view(batch_size, -1))      # flatten every image in the batch\n",
    "        loss = loss_fn(out, labels)\n",
    "        loss.backward()                             # backpropagate to compute the gradients\n",
    "        optimizer.step()                            # update the parameters\n",
    "    # Only the loss of the last batch processed in the epoch is reported.\n",
    "    print('Epoch: %d, Loss: %f' % (epoch, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.812000\n"
     ]
    }
   ],
   "source": [
    "# Validation batches are served in a fixed order (no shuffling).\n",
    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
    "\n",
    "total = 0\n",
    "correct = 0\n",
    "with torch.no_grad():  # evaluation only, so gradient tracking is switched off\n",
    "    for imgs, labels in val_loader:\n",
    "        batch_size = imgs.shape[0]\n",
    "        outputs = model(imgs.view(batch_size, -1))\n",
    "        # The index of the largest output along dim 1 is the predicted class.\n",
    "        _, predicted = outputs.max(dim=1)\n",
    "        total += len(labels)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "# Fraction of validation samples whose predicted class equals the label.\n",
    "print('Accuracy: %f' % (correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 3, 32, 32])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the batch last bound to `imgs` by an earlier loop cell (relies on kernel state).\n",
    "imgs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 加深网络\n",
    "\n",
    "model3 = nn.Sequential(\n",
    "    nn.Linear(3072, 1024),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(1024, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, 128),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(128, 2)\n",
    "    )\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1574402, [1572864, 512, 1024, 2])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element count of every trainable parameter tensor in `model`.\n",
    "# NOTE(review): this inspects `model`, not the deeper `model3`; confirm that is intended.\n",
    "numel_list = [param.numel() for param in model.parameters() if param.requires_grad]\n",
    "\n",
    "# Cell output: total trainable parameter count, then the per-tensor counts.\n",
    "sum(numel_list), numel_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
